package database

import (
	"errors"
	"fmt"
	"gitee.com/zhucheer/orange/cfg"
	"gitee.com/zhucheer/orange/logger"
	"gitee.com/zhucheer/orange/queue"
	"github.com/gomodule/redigo/redis"
	"github.com/zhuCheer/pool"
	"sync"
	"time"
)

// redisConn is the package-wide Redis manager; NewRedis creates it on first
// use and GetRedis reads it.
var redisConn *RedisDB

// RedisDB holds named Redis connection pools built from the "database.redis"
// configuration section.
type RedisDB struct {
	// connPool maps a configured redis name to its connection pool.
	connPool map[string]pool.Pool
	// connList is initialised by NewRedis.
	// NOTE(review): not referenced elsewhere in this file — confirm whether it is still used.
	connList *queue.Queue
	// count is set by RegisterAll to the number of configured entries,
	// including entries whose pool creation later fails in Register.
	count    int
	// lock is held by insertPool while it writes connPool.
	lock     sync.Mutex
}

// NewRedis returns the package-wide Redis manager, allocating it (with an
// empty pool map and a fresh queue) the first time it is called. Every later
// call returns that same instance.
func NewRedis() DataBase {
	if redisConn == nil {
		redisConn = &RedisDB{
			connPool: map[string]pool.Pool{},
			connList: queue.NewQueue(),
		}
	}
	return redisConn
}

// RegisterAll registers a connection pool for every Redis entry found under
// the "database.redis" configuration key, and records how many entries were
// configured in re.count.
func (re *RedisDB) RegisterAll() {
	configured := cfg.Config.GetMap("database.redis")
	re.count = len(configured)

	for name := range configured {
		re.Register(name)
	}
}

// Register builds a connection pool for the Redis instance configured under
// "database.redis.<name>" and stores it under name.
//
// Connection settings (addr, password, dbnum) are read via cfg.Config. Pool
// sizing (initCap, maxCap, idleTimeout in seconds) is read via
// getDBIntConfig. If the pool cannot be created, the error is logged and
// nothing is registered.
func (re *RedisDB) Register(name string) {
	prefix := "database.redis." + name
	addr := cfg.Config.GetString(prefix + ".addr")
	password := cfg.Config.GetString(prefix + ".password")
	dbnum := cfg.Config.GetInt(prefix + ".dbnum")

	initCap := getDBIntConfig("redis", name, "initCap")
	maxCap := getDBIntConfig("redis", name, "maxCap")
	idleTimeout := getDBIntConfig("redis", name, "idleTimeout")

	// connRedis dials a new connection. It authenticates when a password is
	// configured and selects dbnum when it is positive. If either step fails,
	// the half-initialised connection is closed so it does not leak.
	connRedis := func() (interface{}, error) {
		conn, err := redis.Dial("tcp", addr)
		if err != nil {
			return nil, err
		}
		if password != "" {
			if _, err := conn.Do("AUTH", password); err != nil {
				// Best-effort close; the AUTH error is the one worth reporting.
				_ = conn.Close()
				return nil, err
			}
		}
		if dbnum > 0 {
			if _, err := conn.Do("SELECT", dbnum); err != nil {
				// Best-effort close; the SELECT error is the one worth reporting.
				_ = conn.Close()
				return nil, err
			}
		}
		return conn, nil
	}

	// closeRedis closes a pooled connection.
	closeRedis := func(v interface{}) error {
		return v.(redis.Conn).Close()
	}

	// pingRedis checks that a pooled connection still answers PING with PONG.
	pingRedis := func(v interface{}) error {
		conn := v.(redis.Conn)

		val, err := redis.String(conn.Do("PING"))
		if err != nil {
			return err
		}
		if val != "PONG" {
			return errors.New("redis ping is error ping => " + val)
		}
		return nil
	}

	// Create the pool with capacities taken from configuration.
	p, err := pool.NewChannelPool(&pool.Config{
		InitialCap: initCap,
		MaxCap:     maxCap,
		Factory:    connRedis,
		Close:      closeRedis,
		Ping:       pingRedis,
		// Maximum idle time. Connections idle longer than this are closed,
		// which avoids handing out connections that have silently gone stale
		// (EOF) while idle.
		IdleTimeout: time.Duration(idleTimeout) * time.Second,
	})
	if err != nil {
		logger.Error("register redis conn [%s] error:%v", name, err)
		return
	}
	re.insertPool(name, p)
}

// insertPool stores p in the connection-pool map under name, replacing any
// pool previously stored under the same name.
//
// NOTE(review): a replaced pool is not released here, so its idle
// connections stay open. Consider releasing it if re-registration is expected.
func (re *RedisDB) insertPool(name string, p pool.Pool) {
	re.lock.Lock()
	defer re.lock.Unlock()

	// Create the map lazily *under the lock*. Otherwise two concurrent
	// callers could each allocate a fresh map and one would drop the other's
	// entry.
	if re.connPool == nil {
		re.connPool = make(map[string]pool.Pool)
	}
	re.connPool[name] = p
}

// getDB borrows a connection from the pool registered under name.
//
// It returns the connection, a put function that hands that connection back
// to the pool it came from, and an error. put is never nil: on error it is a
// no-op, so callers may defer it unconditionally.
//
// An error is returned when no pool is registered under name, or when the
// pool fails to provide a connection.
func (re *RedisDB) getDB(name string) (conn interface{}, put func(), err error) {
	put = func() {}

	// Read the map under the same lock insertPool holds while writing, so a
	// concurrent registration cannot race with this lookup.
	re.lock.Lock()
	p, ok := re.connPool[name]
	re.lock.Unlock()
	if !ok {
		return nil, put, errors.New("no redis connect")
	}

	conn, err = p.Get()
	if err != nil {
		return nil, put, fmt.Errorf("redis get connect err:%v", err)
	}

	// Capture p instead of re-reading the map at call time. The connection
	// then returns to its own pool even if name is re-registered meanwhile.
	put = func() {
		// Best-effort release: Put's error is discarded.
		_ = p.Put(conn)
	}
	return conn, put, nil
}

// putDB returns db to the pool registered under name.
//
// It returns an error if no pool is registered under name; otherwise it
// returns whatever the pool's Put reports.
func (re *RedisDB) putDB(name string, db interface{}) error {
	// Read the map under the same lock insertPool holds while writing.
	re.lock.Lock()
	p, ok := re.connPool[name]
	re.lock.Unlock()
	if !ok {
		return errors.New("no redis connect")
	}
	return p.Put(db)
}

// GetRedis borrows a Redis connection from the pool registered under name.
//
// It returns the connection, a put function that returns it to its pool, and
// an error. put is always non-nil (a no-op on error), so `defer put()` is safe
// right after the call.
//
// An error is returned when NewRedis has not been called yet, when no pool
// is registered under name, when the pool cannot supply a connection, or
// when the pooled value is not a redis.Conn.
func GetRedis(name string) (db redis.Conn, put func(), err error) {
	put = func() {}
	if redisConn == nil {
		return nil, put, errors.New("db connect is nil")
	}

	conn, put, err := redisConn.getDB(name)
	if err != nil {
		return nil, put, err
	}

	db, ok := conn.(redis.Conn)
	if !ok {
		// Foreign values can reach the pool via putDB, whose parameter is
		// interface{}. Report an error instead of panicking, and do not put
		// the value back: the pool's ping would panic on it later.
		return nil, func() {}, fmt.Errorf("redis get connect err:unexpected type %T", conn)
	}
	return db, put, nil
}
